Author

Zeju Li

Vector Quantized Variational AutoEncoder (VQ-VAE)

Adapted from https://github.com/Jackson-Kang/Pytorch-VAE-tutorial

Consider downloading this Jupyter Notebook and running it locally, or test it with Colab.


Install required packages

!pip install torch
!pip install matplotlib
Requirement already satisfied: torch in /opt/anaconda3/lib/python3.12/site-packages (2.8.0)
Requirement already satisfied: filelock in /opt/anaconda3/lib/python3.12/site-packages (from torch) (3.13.1)
Requirement already satisfied: typing-extensions>=4.10.0 in /opt/anaconda3/lib/python3.12/site-packages (from torch) (4.15.0)
Requirement already satisfied: setuptools in /opt/anaconda3/lib/python3.12/site-packages (from torch) (75.1.0)
Requirement already satisfied: sympy>=1.13.3 in /opt/anaconda3/lib/python3.12/site-packages (from torch) (1.14.0)
Requirement already satisfied: networkx in /opt/anaconda3/lib/python3.12/site-packages (from torch) (3.3)
Requirement already satisfied: jinja2 in /opt/anaconda3/lib/python3.12/site-packages (from torch) (3.1.4)
Requirement already satisfied: fsspec in /opt/anaconda3/lib/python3.12/site-packages (from torch) (2024.6.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/anaconda3/lib/python3.12/site-packages (from sympy>=1.13.3->torch) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/anaconda3/lib/python3.12/site-packages (from jinja2->torch) (2.1.3)
Requirement already satisfied: matplotlib in /opt/anaconda3/lib/python3.12/site-packages (3.9.2)
Requirement already satisfied: contourpy>=1.0.1 in /opt/anaconda3/lib/python3.12/site-packages (from matplotlib) (1.2.0)
Requirement already satisfied: cycler>=0.10 in /opt/anaconda3/lib/python3.12/site-packages (from matplotlib) (0.11.0)
Requirement already satisfied: fonttools>=4.22.0 in /opt/anaconda3/lib/python3.12/site-packages (from matplotlib) (4.51.0)
Requirement already satisfied: kiwisolver>=1.3.1 in /opt/anaconda3/lib/python3.12/site-packages (from matplotlib) (1.4.4)
Requirement already satisfied: numpy>=1.23 in /opt/anaconda3/lib/python3.12/site-packages (from matplotlib) (1.26.4)
Requirement already satisfied: packaging>=20.0 in /opt/anaconda3/lib/python3.12/site-packages (from matplotlib) (24.1)
Requirement already satisfied: pillow>=8 in /opt/anaconda3/lib/python3.12/site-packages (from matplotlib) (10.4.0)
Requirement already satisfied: pyparsing>=2.3.1 in /opt/anaconda3/lib/python3.12/site-packages (from matplotlib) (3.1.2)
Requirement already satisfied: python-dateutil>=2.7 in /opt/anaconda3/lib/python3.12/site-packages (from matplotlib) (2.9.0.post0)
Requirement already satisfied: six>=1.5 in /opt/anaconda3/lib/python3.12/site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)
# Import required packages
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from tqdm import tqdm
from torchvision.utils import save_image, make_grid
# Model Hyperparameters

dataset_path = '~/datasets'

DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

batch_size = 100

# note that VQ-VAE embeds spatial information, thus it needs a CNN
# the latent space also preserves spatial information

input_dim = 1
hidden_dim = 32
latent_dim = 16

n_embeddings = 32 # the size of the codebook (number of code vectors)
output_dim = 1
commitment_beta = 0.25

lr = 2e-4

epochs = 2 # just for illustration...

print_step = 100
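
With these settings, the encoder maps each 28×28 MNIST digit to a 7×7 grid of 16-dimensional latent vectors, and each vector in that grid is replaced by one of the 32 codebook entries, so an image is ultimately described by 7×7 = 49 discrete indices (this is the indices_shape used for sampling in Step 6).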

Step 1. Load (or download) Dataset

from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


mnist_transform = transforms.Compose([
        transforms.ToTensor(),
])

kwargs = {'num_workers': 1, 'pin_memory': True} 

train_dataset = MNIST(dataset_path, transform=mnist_transform, train=True, download=True)
test_dataset  = MNIST(dataset_path, transform=mnist_transform, train=False, download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
test_loader  = DataLoader(dataset=test_dataset,  batch_size=batch_size, shuffle=False, **kwargs)

Step 2. Define our model: Vector Quantized Variational AutoEncoder (VQ-VAE)
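
The model consists of three parts: a convolutional Encoder that maps an image to a spatial grid of continuous latent vectors, a codebook (VQEmbeddingEMA) that replaces each latent vector with its nearest embedding and maintains the embeddings with an exponential moving average, and a convolutional Decoder that reconstructs the image from the quantized grid.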

class Encoder(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, output_dim, kernel_size=(4, 4, 3, 1), stride=2):
        super(Encoder, self).__init__()
        
        kernel_1, kernel_2, kernel_3, kernel_4 = kernel_size
        
        self.strided_conv_1 = nn.Conv2d(input_dim, hidden_dim, kernel_1, stride, padding=1)
        self.strided_conv_2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_2, stride, padding=1)
        
        self.residual_conv_1 = nn.Conv2d(hidden_dim, hidden_dim, kernel_3, padding=1)
        self.residual_conv_2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_4, padding=0)
        
        self.proj = nn.Conv2d(hidden_dim, output_dim, kernel_size=1)
        
    def forward(self, x):
        
        x = self.strided_conv_1(x)
        x = self.strided_conv_2(x)
        
        x = F.relu(x)
        y = self.residual_conv_1(x)
        y = y+x
        
        x = F.relu(y)
        y = self.residual_conv_2(x)
        y = y+x
        
        y = self.proj(y)
        return y
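
As a quick sanity check (not part of the original notebook), the two stride-2 convolutions reduce a 28×28 input to a 7×7 grid, while the residual and 1×1 projection layers preserve the spatial size:

# Illustrative shape check, assuming the hyperparameters defined above
enc_check = Encoder(input_dim=1, hidden_dim=32, output_dim=16)
print(enc_check(torch.zeros(1, 1, 28, 28)).shape)  # expected: torch.Size([1, 16, 7, 7])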
class VQEmbeddingEMA(nn.Module):
    def __init__(self, n_embeddings, embedding_dim, commitment_cost=0.25, decay=0.999, epsilon=1e-5):
        super(VQEmbeddingEMA, self).__init__()
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.epsilon = epsilon
        
        init_bound = 1 / n_embeddings
        embedding = torch.Tensor(n_embeddings, embedding_dim)
        embedding.uniform_(-init_bound, init_bound)
        self.register_buffer("embedding", embedding)
        self.register_buffer("ema_count", torch.zeros(n_embeddings))
        self.register_buffer("ema_weight", self.embedding.clone())

    def encode(self, x):
        M, D = self.embedding.size()
        x_flat = x.detach().reshape(-1, D)

        distances = torch.cdist(x_flat, self.embedding, p=2) ** 2  # squared Euclidean distance to each code vector

        indices = torch.argmin(distances.float(), dim=-1)
        quantized = F.embedding(indices, self.embedding)
        quantized = quantized.view_as(x)
        return quantized, indices.view(x.size(0), x.size(1))
    
    def retrieve_random_codebook(self, random_indices):
        quantized = F.embedding(random_indices, self.embedding)
        quantized = quantized.transpose(1, 3)
        
        return quantized

    def forward(self, x):
        M, D = self.embedding.size()
        x_flat = x.detach().reshape(-1, D)
        
        distances = torch.cdist(x_flat, self.embedding, p=2) ** 2  # squared Euclidean distance to each code vector

        indices = torch.argmin(distances.float(), dim=-1)
        encodings = F.one_hot(indices, M).float()
        quantized = F.embedding(indices, self.embedding)
        quantized = quantized.view_as(x)
        
        if self.training:
            self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0)
            n = torch.sum(self.ema_count)
            self.ema_count = (self.ema_count + self.epsilon) / (n + M * self.epsilon) * n

            dw = torch.matmul(encodings.t(), x_flat)
            self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * dw
            self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1)

        codebook_loss = F.mse_loss(x.detach(), quantized)
        e_latent_loss = F.mse_loss(x, quantized.detach())
        commitment_loss = self.commitment_cost * e_latent_loss

        quantized = x + (quantized - x).detach()

        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return quantized, commitment_loss, codebook_loss, perplexity
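
Two details of the forward pass are worth noting. Because the nearest-neighbour lookup (argmin) is not differentiable, the line quantized = x + (quantized - x).detach() implements the straight-through estimator: the decoder sees the quantized codes in the forward pass, while gradients flow back to the encoder as if quantization were the identity. The codebook itself is not trained by these gradients; since embedding is registered as a buffer, it is updated only by the exponential-moving-average rule inside the if self.training block. A minimal, self-contained sketch of the straight-through trick (torch.round stands in for the codebook lookup):

# Straight-through estimator demo (illustrative only)
z = torch.tensor([0.3, 1.7], requires_grad=True)
q = torch.round(z)                  # non-differentiable "quantization"
st = z + (q - z).detach()           # forward value equals q
st.sum().backward()
print(st.detach())                  # tensor([0., 2.])
print(z.grad)                       # tensor([1., 1.]) -- gradient passes straight through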
class Decoder(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, output_dim, kernel_sizes=(1, 3, 2, 2), stride=2):
        super(Decoder, self).__init__()
        
        kernel_1, kernel_2, kernel_3, kernel_4 = kernel_sizes
        
        self.in_proj = nn.Conv2d(input_dim, hidden_dim, kernel_size=1)
        
        self.residual_conv_1 = nn.Conv2d(hidden_dim, hidden_dim, kernel_1, padding=0)
        self.residual_conv_2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_2, padding=1)
        
        self.strided_t_conv_1 = nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_3, stride, padding=0)
        self.strided_t_conv_2 = nn.ConvTranspose2d(hidden_dim, output_dim, kernel_4, stride, padding=0)
        
    def forward(self, x):

        x = self.in_proj(x)
        
        y = self.residual_conv_1(x)
        y = y+x
        x = F.relu(y)
        
        y = self.residual_conv_2(x)
        y = y+x
        y = F.relu(y)
        
        y = self.strided_t_conv_1(y)
        y = self.strided_t_conv_2(y)
        
        return y
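
Mirroring the encoder, the two stride-2 transposed convolutions upsample the 7×7 quantized grid back to 28×28. A quick illustrative check, under the same assumptions as before:

dec_check = Decoder(input_dim=16, hidden_dim=32, output_dim=1)
print(dec_check(torch.zeros(1, 16, 7, 7)).shape)  # expected: torch.Size([1, 1, 28, 28])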
class Model(nn.Module):
    def __init__(self, Encoder, Codebook, Decoder):
        super(Model, self).__init__()
        self.encoder = Encoder
        self.codebook = Codebook
        self.decoder = Decoder
                
    def forward(self, x):
        z = self.encoder(x)
        z_quantized, commitment_loss, codebook_loss, perplexity = self.codebook(z)
        x_hat = self.decoder(z_quantized)
        
        return x_hat, commitment_loss, codebook_loss, perplexity
encoder = Encoder(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=latent_dim)
codebook = VQEmbeddingEMA(n_embeddings=n_embeddings, embedding_dim=latent_dim)
decoder = Decoder(input_dim=latent_dim, hidden_dim=hidden_dim, output_dim=output_dim)

model = Model(Encoder=encoder, Codebook=codebook, Decoder=decoder).to(DEVICE)

Step 3. Define the loss function (reconstruction loss) and optimizer

from torch.optim import Adam

mse_loss = nn.MSELoss()

optimizer = Adam(model.parameters(), lr=lr)
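
For reference, the total loss assembled in Step 4 follows the standard VQ-VAE objective, where $\operatorname{sg}[\cdot]$ denotes stop-gradient (.detach() in PyTorch):

$$
\mathcal{L} \;=\; \underbrace{\lVert x - \hat{x}\rVert_2^2}_{\texttt{recon\_loss}}
\;+\; \underbrace{\lVert \operatorname{sg}[z_e(x)] - e\rVert_2^2}_{\texttt{codebook\_loss}}
\;+\; \beta\,\underbrace{\lVert z_e(x) - \operatorname{sg}[e]\rVert_2^2}_{\texttt{commitment\_loss}}
$$

with $\beta$ = commitment_beta = 0.25. In the EMA variant used here, the codebook embeddings are actually updated by the moving-average rule inside VQEmbeddingEMA rather than by gradients from the codebook term.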

Step 4. Train Vector Quantized Variational AutoEncoder (VQ-VAE)

print("Start training VQ-VAE...")
model.train()

for epoch in range(epochs):
    overall_loss = 0
    for batch_idx, (x, _) in enumerate(train_loader):
        x = x.to(DEVICE)

        optimizer.zero_grad()

        x_hat, commitment_loss, codebook_loss, perplexity = model(x)
        recon_loss = mse_loss(x_hat, x)
        
        loss =  recon_loss + commitment_loss * commitment_beta + codebook_loss
                
        loss.backward()
        optimizer.step()
        
        if batch_idx % print_step == 0:
            print("epoch:", epoch + 1, "(", batch_idx + 1, ") recon_loss:", recon_loss.item(), " perplexity: ", perplexity.item(), 
              " commit_loss: ", commitment_loss.item(), "\n\t codebook loss: ", codebook_loss.item(), " total_loss: ", loss.item(), "\n")
    
print("Finish!!")
Start training VQ-VAE...
epoch: 1 ( 1 ) recon_loss: 0.14143875241279602  perplexity:  10.034667015075684  commit_loss:  0.00490464735776186 
     codebook loss:  0.01961858943104744  total_loss:  0.1622834950685501 

epoch: 1 ( 101 ) recon_loss: 0.05592184141278267  perplexity:  22.171672821044922  commit_loss:  0.10291137546300888 
     codebook loss:  0.4116455018520355  total_loss:  0.49329519271850586 

epoch: 1 ( 201 ) recon_loss: 0.035925429314374924  perplexity:  25.494213104248047  commit_loss:  0.16113601624965668 
     codebook loss:  0.6445440649986267  total_loss:  0.7207534909248352 

epoch: 1 ( 301 ) recon_loss: 0.028422372415661812  perplexity:  27.116924285888672  commit_loss:  0.1424388438463211 
     codebook loss:  0.5697553753852844  total_loss:  0.6337874531745911 

epoch: 1 ( 401 ) recon_loss: 0.02394694648683071  perplexity:  27.800132751464844  commit_loss:  0.12819147109985352 
     codebook loss:  0.5127658843994141  total_loss:  0.5687606930732727 

epoch: 1 ( 501 ) recon_loss: 0.023408997803926468  perplexity:  28.292924880981445  commit_loss:  0.12635277211666107 
     codebook loss:  0.5054110884666443  total_loss:  0.5604082942008972 

epoch: 2 ( 1 ) recon_loss: 0.021908951923251152  perplexity:  28.13805389404297  commit_loss:  0.1150847002863884 
     codebook loss:  0.4603388011455536  total_loss:  0.5110189318656921 

epoch: 2 ( 101 ) recon_loss: 0.020374955609440804  perplexity:  28.325942993164062  commit_loss:  0.10786083340644836 
     codebook loss:  0.43144333362579346  total_loss:  0.4787834882736206 

epoch: 2 ( 201 ) recon_loss: 0.020226318389177322  perplexity:  28.37371063232422  commit_loss:  0.11535041779279709 
     codebook loss:  0.46140167117118835  total_loss:  0.5104656219482422 

epoch: 2 ( 301 ) recon_loss: 0.01963076926767826  perplexity:  28.430225372314453  commit_loss:  0.10961394757032394 
     codebook loss:  0.4384557902812958  total_loss:  0.4854900538921356 

epoch: 2 ( 401 ) recon_loss: 0.018267327919602394  perplexity:  28.544404983520508  commit_loss:  0.1052127480506897 
     codebook loss:  0.4208509922027588  total_loss:  0.46542149782180786 

epoch: 2 ( 501 ) recon_loss: 0.018102342262864113  perplexity:  28.274341583251953  commit_loss:  0.1050005704164505 
     codebook loss:  0.420002281665802  total_loss:  0.4643547534942627 

Finish!!

Step 5. Evaluate the model

import matplotlib.pyplot as plt
def draw_sample_image(x, postfix):
  
    plt.figure(figsize=(8,8))
    plt.axis("off")
    plt.title("Visualization of {}".format(postfix))
    plt.imshow(np.transpose(make_grid(x.detach().cpu(), padding=2, normalize=True), (1, 2, 0)))
model.eval()

with torch.no_grad():

    for batch_idx, (x, _) in enumerate(tqdm(test_loader)):

        x = x.to(DEVICE)
        x_hat, commitment_loss, codebook_loss, perplexity = model(x)
 
        print("perplexity: ", perplexity.item(),"commit_loss: ", commitment_loss.item(), "  codebook loss: ", codebook_loss.item())
        break
perplexity:  27.08778190612793 commit_loss:  0.09418725222349167   codebook loss:  0.3767490088939667

Let’s visualize test samples alongside their reconstructions.

def show_image_grid_detailed(x, nrows=10, ncols=10):
    x = x.view(-1, 28, 28)
    
    fig, axes = plt.subplots(nrows, ncols, figsize=(12, 12))
    
    for i in range(nrows):
        for j in range(ncols):
            idx = i * ncols + j
            if idx < len(x):
                axes[i, j].imshow(x[idx].cpu().numpy(), cmap='gray')
            axes[i, j].axis('off')
    
    plt.tight_layout()
    plt.show()
show_image_grid_detailed(x)

show_image_grid_detailed(x_hat)

Step 6. Generate samples via random codes

def random_sample_image(codebook, decoder, indices_shape):
    
    random_indices = torch.floor(torch.rand(indices_shape) * n_embeddings).long().to(DEVICE) # each spatial position samples a code index independently
    codes = codebook.retrieve_random_codebook(random_indices)
    x_hat = decoder(codes.to(DEVICE))

    return x_hat
x_hat_randn = random_sample_image(codebook, decoder, indices_shape=(100, 7, 7))
show_image_grid_detailed(x_hat_randn.detach())
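
Because each code index is drawn independently and uniformly at random, these samples ignore the statistical structure that real digits induce over the 7×7 grid of codes, so they usually look like textured noise rather than clean digits. In the original VQ-VAE paper this is handled by fitting an autoregressive prior (e.g. a PixelCNN) over the discrete latents and sampling indices from that prior instead.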
